import cv2
import torch
from torchvision import transforms
from PIL import Image, ImageOps

import numpy as np
import scipy.misc as misc
import os
import glob
import csv
from seg_system.application_config import ApplicationConfig

# Seed the NumPy and PyTorch RNGs once at import time so repeated runs are
# reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
# Explicit all-GPU seeding, currently disabled.
# torch.cuda.manual_seed_all(seed)


# Optional cuDNN determinism switches, currently disabled. The attribute is
# spelled `deterministic`; a misspelled name would be set silently and have no
# effect.
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

m_stage2 = None  # cached grading network, filled in by load_net() so the model is not reloaded on every call


def load_net():
    """Return the grading network, loading it from disk on first use.

    The loaded object is cached in the module-level ``m_stage2``. Repeated
    calls (``predict`` calls this once per prediction) therefore reuse the
    same instance instead of re-reading the checkpoint.

    The checkpoint is mapped onto ``ApplicationConfig.SystemConfig.DEVICE``
    when ``GRADE_USE_CUDA`` is enabled, and onto the CPU otherwise. The
    checkpoint path is relative to the current working directory.

    Returns:
        The object stored in the checkpoint file. It is a full pickled model:
        ``predict`` calls it directly.
    """
    global m_stage2
    # Compare against None explicitly. ``if m_stage2:`` would call ``bool()``
    # on the model, and container modules such as ``nn.Sequential`` define
    # ``__len__``, so a len-0 module would look "not loaded" and be reloaded.
    if m_stage2 is not None:
        return m_stage2

    if ApplicationConfig.SystemConfig.GRADE_USE_CUDA:
        device = ApplicationConfig.SystemConfig.DEVICE
    else:
        device = torch.device('cpu')

    # torch.load unpickles arbitrary objects, so only point it at a trusted
    # checkpoint. NOTE(review): PyTorch >= 2.6 defaults to weights_only=True,
    # which rejects full pickled models. If loading fails there, this call
    # probably needs weights_only=False; confirm against the deployed version.
    m_stage2 = torch.load("./checkpoint/DeepGrading18-fold5-0.8564.pth",
                          map_location=device)
    return m_stage2


def _to_batch(array, transform):
    """Wrap a numpy image with ``Image.fromarray``, apply ``transform``, and add a leading batch dim of size 1."""
    return transform(Image.fromarray(array)).unsqueeze(0)


def predict(img, seg, roi):
    """Grade one sample with the cached network from ``load_net``.

    Args:
        img: original image as a numpy array.
        seg: segmentation map as a numpy array.
        roi: region-of-interest image as a numpy array.
        All three are converted with ``PIL.Image.fromarray``. They are
        presumably the arrays returned by ``cv2.imread(..., flags=-1)``;
        TODO confirm the expected channel count and dtype.

    Returns:
        ``(predictions, probs)`` as numpy arrays. ``predictions`` is the
        0-based argmax class index per sample. ``probs`` holds the softmax
        probabilities. Both are taken along dim 1 of the network's last output.
    """
    net = load_net()
    net.eval()
    # img and seg share a 304x304 pipeline. The ROI uses 112x112 with the
    # same normalization constants.
    transform_test = transforms.Compose([
        transforms.Resize((304, 304)),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.339, std=0.138),
    ])
    transform_roi = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.339, std=0.138),
    ])
    with torch.no_grad():
        img_t = _to_batch(img, transform_test)
        seg_t = _to_batch(seg, transform_test)
        roi_t = _to_batch(roi, transform_roi)
        # The network takes an (img, seg, img) channel stack plus the ROI
        # duplicated along the channel dim.
        image = torch.cat((img_t, seg_t, img_t), dim=1)
        roi_t = torch.cat((roi_t, roi_t), dim=1)
        if ApplicationConfig.SystemConfig.GRADE_USE_CUDA:
            image = image.to(ApplicationConfig.SystemConfig.DEVICE)
            roi_t = roi_t.to(ApplicationConfig.SystemConfig.DEVICE)
        # The network returns four outputs. Only the last one is used, as the
        # per-class scores for softmax/argmax.
        _x1, _x2, _roi_out, logits = net(image, roi_t)
        # No .detach()/.data needed: autograd is disabled inside no_grad().
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        predictions = torch.argmax(logits, dim=1).cpu().numpy()
    return predictions, probs

# if __name__ == '__main__':
#     '''
#     Note:
#         Pass 'img', 'seg' and 'roi' to the grading function 'predict' as numpy arrays only.
#         The following shows an example of how to use this function
#     '''
#     img_file = "./test_image/original_img.jpg"
#     seg_file = "./test_image/segmentation.png"
#     roi_file = "./test_image/roi.png"
#     img = cv2.imread(img_file, flags=-1)
#     seg = cv2.imread(seg_file, flags=-1)
#     roi = cv2.imread(roi_file, flags=-1)
#     # images read with 'cv2' are returned as numpy arrays
#     predictions, probs = predict(img, seg, roi)
#     print(predictions, probs)
#
#     # The expected output is:
#     #           [1] [[0.00171472 0.66045445 0.33193085 0.00589993]]
#     # [1] is the predicted 0-based class index, here tortuosity level 2. You can obtain the level with:
#     #           level = predictions[0] + 1
#     # [[0.00171472 0.66045445 0.33193085 0.00589993]] are the predicted probabilities of each level.
#     # There are 4 tortuosity levels; you can get the probability of each level with:
#     #           probabilities = probs[0]
#     # The probabilities of tortuosity levels 1 to 4 are 0.0017, 0.6604, 0.3319, and 0.0059, respectively.
